import torch
import torch.nn as nn
import math 


class Reshape(nn.Module):
    '''
    Reshape every sample of a batch to a fixed target shape.

    Wrapping the reshape in a module lets it be used as one step of an
    ``nn.Sequential`` (for example right before a fully connected layer).
    The batch dimension (dim 0) of the input is always preserved.

    Args:
        shape: per-sample target shape. Either a single int (each sample is
            flattened to that many features) or a list/tuple of ints.
    '''
    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape  # int, or a list/tuple of ints

    def __repr__(self):
        return 'Reshape({})'.format(self.shape)

    def forward(self, x):
        '''
        Return ``x`` reshaped to ``(batch, *shape)``.

        Raises:
            RuntimeError: if the number of elements per sample does not
                match the configured shape.
        '''
        # The batch size of the most recent input has always been stored on
        # the module; it is kept so existing readers of ``.bs`` keep working.
        self.bs = x.size(0)

        if isinstance(self.shape, int):
            target = (self.shape,)
        else:
            target = tuple(self.shape)

        # reshape() returns a view for contiguous inputs, exactly like view(),
        # but also accepts non-contiguous tensors (e.g. after transpose or
        # permute), where view() would raise.
        return x.reshape(self.bs, *target)






class BasicBlock(nn.Module):
    '''
    Residual block made of two 3x3 conv + batch-norm stages.

    The block computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
    where ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a
    ``downsample`` module is supplied.

    Args:
        inplanes: number of input channels.
        planes: number of output channels of both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to ``x`` to produce the shortcut
            branch; ``None`` means the input is added unchanged.
    '''

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()

        # Settings shared by both 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(inplanes, planes, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch: conv -> bn -> relu -> conv -> bn.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut branch: identity unless a downsample module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)


class ResNet_Cifar(nn.Module):
    '''
    ResNet for small images: a 3x3 stem followed by three residual stages.

    Structure: conv(3 -> 16) + bn + relu, then three stages with 16, 32 and
    64 channels (the second and third start with a stride-2 block), an 8x8
    average pool and a linear classifier on 64 features.

    The 8x8 pool followed by ``Linear(64, ...)`` only lines up when the last
    stage emits an 8x8 feature map, which is the case for 32x32 inputs
    (two stride-2 stages: 32 -> 16 -> 8).

    Args:
        block: residual block class, called as
            ``block(inplanes, planes, stride, downsample)``. The classifier
            width assumes the block does not expand its channel count.
        layers: number of blocks in each of the three stages.
        num_classes: size of the classifier output.
    '''

    # block expansion is 1

    def __init__(self, block=BasicBlock, layers=(3, 3, 3), num_classes=10):
        super(ResNet_Cifar, self).__init__()
        # Channel count fed into the next block; updated by _make_layer.
        self.inplanes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)

        self.avgpool = nn.AvgPool2d(8, stride=1)

        self.fc = nn.Linear(64, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Zero-mean normal init with std = sqrt(2 / fan_out), where
                # fan_out = kernel_h * kernel_w * out_channels.
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / fan_out))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        '''
        Build one stage consisting of ``blocks`` residual blocks.

        Only the first block uses ``stride`` and, when the spatial size or
        channel count changes, a 1x1 conv + bn projection on its shortcut.
        Updates ``self.inplanes`` to ``planes`` for the following stage.
        '''
        downsample = None
        if stride != 1 or self.inplanes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes)
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        # Flatten everything except the batch dimension for the classifier.
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


class ResNet20Prime(nn.Module):
    '''
    ResNet-20 shaped network built as one flat ``nn.Sequential``.

    The stack is: conv(3 -> 16) + bn + relu, then three stages of three
    blocks each (16, 32 and 64 channels; the 32 and 64 channel stages start
    with a stride-2 conv), an 8x8 average pool, a reshape to 64 features
    and a linear classifier.

    Each block contributes ``conv, bn, relu, conv, bn``. No shortcut
    connections are built: the 1x1 stride-2 projections a residual version
    would need (16 -> 32 and 32 -> 64, each followed by batch norm) are not
    part of the stack.

    NOTE(review): unlike ``BasicBlock.forward`` there is no ReLU after the
    second batch norm of a block -- confirm this is intended. Adding one
    would shift the indices of ``self.net`` and thus the state_dict keys.

    Args:
        input_size: currently unused; kept for interface compatibility.
        num_classes: size of the classifier output.
    '''
    def __init__(self, input_size=(3, 32, 32), num_classes=10):
        super(ResNet20Prime, self).__init__()

        def conv3x3(in_ch, out_ch, stride=1):
            # All convolutions in the stack share these settings.
            return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False)

        # Stem.
        layers = [conv3x3(3, 16), nn.BatchNorm2d(16), nn.ReLU(inplace=True)]

        # (output channels, stride of the stage's first conv) per stage;
        # every stage holds 3 blocks. Modules are created in exactly the
        # order they appear in the Sequential.
        in_ch = 16
        for out_ch, first_stride in ((16, 1), (32, 2), (64, 2)):
            for block_idx in range(3):
                stride = first_stride if block_idx == 0 else 1
                layers.extend([
                    conv3x3(in_ch, out_ch, stride),
                    nn.BatchNorm2d(out_ch),
                    nn.ReLU(inplace=True),
                    conv3x3(out_ch, out_ch),
                    nn.BatchNorm2d(out_ch),
                ])
                in_ch = out_ch

        layers.append(nn.AvgPool2d(8, stride=1))

        # Hard coded size: 64 channels with a 1x1 map after pooling, which
        # requires the last stage to emit an 8x8 map (e.g. 32x32 inputs).
        layers.append(Reshape(64))
        layers.append(nn.Linear(64, num_classes, bias=True))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def resnet20(**kwargs):
    '''
    Build the default 20-layer model (currently ``ResNet20Prime``).

    Keyword arguments understood by the model constructor (``input_size``
    and ``num_classes``) are forwarded to it; previously every keyword was
    silently dropped, so e.g. ``resnet20(num_classes=100)`` still produced
    a 10-class model. Any other keyword is ignored, as before, so existing
    callers that pass extra options keep working.
    '''
    # model = ResNet_Cifar()
    accepted = ('input_size', 'num_classes')
    ctor_kwargs = {key: value for key, value in kwargs.items() if key in accepted}
    model = ResNet20Prime(**ctor_kwargs)
    return model


if __name__ == '__main__':
    # Smoke test: seed the RNG, build the model, push one random
    # 32x32 RGB sample through it and print the output size and norm.
    torch.manual_seed(0)
    model = resnet20()
    # print(model)
    sample = torch.randn(1, 3, 32, 32)
    logits = model(sample)
    print(logits.size(), logits.norm())
    